"""
Takes code from
https://github.com/deepmind/acme/acme/agents/jax/ail/networks.py
to adopt adversarial imitation learning's regularized discriminator network.
"""

from typing import Optional, NamedTuple, Iterable, Callable
import jax.numpy as jnp
import jax
import haiku as hk
import matplotlib.pyplot as plt
import numpy as np
import optax
import functools
from tqdm import tqdm

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from scipy.stats import norm


class SpectralNormalizedLinear(hk.Module):
    """Linear layer whose weights are rescaled to cap their spectral norm.

    The largest singular value of the weight matrix is estimated by power
    iteration. Whenever that estimate exceeds ``lipschitz_coeff`` the weights
    are divided by the ratio of the two; otherwise they are left untouched.
    This kind of layer is used in iResNet.

    Reference:
      Behrmann et al. Invertible Residual Networks. ICML 2019.
      https://arxiv.org/abs/1811.00995
    """

    def __init__(
        self,
        output_size: int,
        lipschitz_coeff: float,
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        name: Optional[str] = None,
    ):
        """Stores the layer configuration.

        Args:
          output_size: Number of output features.
          lipschitz_coeff: Upper bound enforced on the estimated spectral norm.
          with_bias: If True, a learned bias is added to the output.
          w_init: Weight initializer. When None, a truncated normal with stddev
            ``1 / sqrt(fan_in)`` is created on first call.
          b_init: Bias initializer. When not given, zeros are used.
          name: Name of the module.
        """
        super().__init__(name=name)
        # Layer shape; the input size is only known once the layer is called.
        self.input_size = None
        self.output_size = output_size
        # Parameter initialisation.
        self.with_bias = with_bias
        self.w_init = w_init
        self.b_init = b_init if b_init else jnp.zeros
        # Spectral normalisation settings.
        self.lipschitz_coeff = lipschitz_coeff
        self.num_iterations = 100
        self.eps = 1e-6

    def get_normalized_weights(
        self, weights: jnp.ndarray, renormalize: bool = False
    ) -> jnp.ndarray:
        """Returns ``weights`` scaled down so sigma stays within the bound.

        Args:
          weights: Weight matrix whose last dimension equals ``output_size``.
          renormalize: If True, re-estimate sigma with ``num_iterations`` power
            iterations and store ``u``, ``v`` and ``sigma`` in the Haiku state.
            If False, the stored ``sigma`` is used as is.

        Returns:
          ``weights / max(1, sigma / lipschitz_coeff)``.
        """
        assert weights.shape[-1] == self.output_size

        def _unit(vec):
            # Scale to unit L2 norm; eps keeps rsqrt finite for a zero vector.
            squared_norm = jnp.sum(vec * vec, axis=None, keepdims=True)
            return vec * jax.lax.rsqrt(squared_norm + self.eps)

        sigma = hk.get_state("sigma", (), init=jnp.ones)

        if renormalize:
            # Power iterations to compute spectral norm V*W*U^T.
            u = hk.get_state(
                "u",
                (1, self.output_size),
                weights.dtype,
                init=hk.initializers.RandomNormal(),
            )
            weights_t = weights.T
            for _ in range(self.num_iterations):
                v = _unit(u @ weights_t)
                u = _unit(v @ weights)
            # The singular vectors are treated as constants when differentiating.
            u = jax.lax.stop_gradient(u)
            v = jax.lax.stop_gradient(v)
            sigma = (v @ weights @ u.T)[0, 0]
            hk.set_state("u", u)
            hk.set_state("v", v)
            hk.set_state("sigma", sigma)

        # Only ever shrink the weights: the divisor is clamped at 1.
        return weights / jnp.maximum(1, sigma / self.lipschitz_coeff)

    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Applies the spectrally normalized linear map to ``inputs``.

        Raises:
          ValueError: If ``inputs`` is a scalar (has an empty shape).
        """
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        self.input_size = inputs.shape[-1]
        fan_in = self.input_size
        dtype = inputs.dtype

        weight_initializer = self.w_init
        if weight_initializer is None:
            weight_initializer = hk.initializers.TruncatedNormal(
                stddev=1.0 / np.sqrt(fan_in)
            )

        raw_w = hk.get_parameter(
            "w", [fan_in, self.output_size], dtype, init=weight_initializer
        )
        # Sigma is re-estimated on every call.
        result = jnp.dot(inputs, self.get_normalized_weights(raw_w, renormalize=True))

        if not self.with_bias:
            return result

        bias = hk.get_parameter("b", [self.output_size], dtype, init=self.b_init)
        return result + jnp.broadcast_to(bias, result.shape)


class DiscriminatorMLP(hk.Module):
    """MLP with a single sigmoid output unit, optionally spectrally normalized."""

    def __init__(
        self,
        hidden_layer_sizes: Iterable[int],
        w_init: hk.initializers.Initializer = hk.initializers.RandomNormal(0.05),
        b_init: hk.initializers.Initializer = hk.initializers.RandomNormal(0.05),
        with_bias: bool = True,
        activation: Callable[[jnp.ndarray], jnp.ndarray] = jnp.tanh,
        input_dropout_rate: float = 0.0,
        hidden_dropout_rate: float = 0.0,
        spectral_normalization_lipschitz_coeff: Optional[float] = None,
        name: Optional[str] = None,
    ):
        """Builds the stack of linear layers.

        Args:
          hidden_layer_sizes: Widths of the hidden layers. A final layer of
            width 1 is always appended.
          w_init: Weight initializer passed to every layer.
          b_init: Bias initializer passed to every layer. Must be ``None`` if
            ``with_bias=False``.
          with_bias: Whether each layer adds a bias.
          activation: Function applied after every layer except the last.
            Defaults to ``jnp.tanh``.
          input_dropout_rate: Stored on the module; ``__call__`` does not use it.
          hidden_dropout_rate: Stored on the module; ``__call__`` does not use it.
          spectral_normalization_lipschitz_coeff: If not None, every layer is a
            ``SpectralNormalizedLinear`` whose coefficient is this value raised
            to ``1 / number_of_layers``; otherwise ``hk.Linear`` is used.
          name: Optional name for this module.

        Raises:
          ValueError: If ``with_bias`` is ``False`` and ``b_init`` is not ``None``.
        """
        bias_init_without_bias = b_init is not None and not with_bias
        if bias_init_without_bias:
            raise ValueError("When with_bias=False b_init must not be set.")

        super().__init__(name=name)
        self._activation = activation
        self._input_dropout_rate = input_dropout_rate
        self._hidden_dropout_rate = hidden_dropout_rate

        # Hidden widths followed by the scalar output unit.
        widths = [*hidden_layer_sizes, 1]
        shared_kwargs = dict(w_init=w_init, b_init=b_init, with_bias=with_bias)

        if spectral_normalization_lipschitz_coeff is None:
            make_layer = functools.partial(hk.Linear, **shared_kwargs)
        else:
            # Split the overall coefficient evenly (geometrically) over layers.
            per_layer_coeff = np.power(
                spectral_normalization_lipschitz_coeff, 1.0 / len(widths)
            )
            make_layer = functools.partial(
                SpectralNormalizedLinear,
                lipschitz_coeff=per_layer_coeff,
                **shared_kwargs,
            )

        self._layers = tuple(
            make_layer(output_size=width, name=f"linear_{position}")
            for position, width in enumerate(widths)
        )

    def __call__(
        self,
        inputs: jnp.ndarray,
    ):
        """Runs the network and squashes the final layer's output with a sigmoid."""
        *hidden_layers, output_layer = self._layers
        features = inputs
        for hidden in hidden_layers:
            features = self._activation(hidden(features))
        return jax.nn.sigmoid(output_layer(features))


def _binary_cross_entropy_loss(logit: jnp.ndarray, label: jnp.ndarray) -> jnp.ndarray:
    """Element-wise binary cross-entropy between predictions and labels.

    Despite its name, ``logit`` is passed straight into ``log`` and is
    therefore treated as a probability. A small constant is added inside each
    ``log`` so that probabilities of exactly 0 or 1 do not produce ``log(0)``.
    The two arguments are combined with ordinary broadcasting.
    """
    stabiliser = 1e-6
    positive_term = label * jnp.log(stabiliser + logit)
    negative_term = (1.0 - label) * jnp.log(stabiliser + 1.0 - logit)
    return -(positive_term + negative_term)


class State(NamedTuple):
    """Bundle of everything that changes from one training step to the next."""

    # Optimizer state, as produced by the optax optimizer's init/update.
    opt_state: optax.OptState
    # Haiku module state returned by the transformed network's init/apply.
    state: dict
    # Haiku parameters of the network being trained.
    params: hk.Params


def train_classifier(x, y, n_iters=1000):
    """Trains a ``DiscriminatorMLP`` as a binary classifier and returns it.

    The network has four hidden layers of width 1024 and no spectral
    normalization. It is trained with full-batch Adam (learning rate 1e-4) on
    the binary cross-entropy, starting from ``jax.random.PRNGKey(0)``.

    Args:
      x: Training inputs with one example per row; ``x[0, :]`` is used to
        initialise the network.
      y: Binary labels, one per row of ``x``.
      n_iters: Number of gradient steps.

    Returns:
      A function mapping an array of inputs to the network's sigmoid outputs
      (with a trailing axis of size 1, as produced by the network).
    """

    def discriminator(*args, **kwargs):
        return DiscriminatorMLP(
            hidden_layer_sizes=[1024, 1024, 1024, 1024],
            spectral_normalization_lipschitz_coeff=None,
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator)
    )

    key = jax.random.PRNGKey(0)
    param, model_state = discriminator_transformed.init(key, x[0, :])
    opt = optax.adam(1e-4)
    opt_state = opt.init(param)

    state = State(opt_state, model_state, param)

    def loss_fn(params, state, x, y):
        y_pred, model_state = discriminator_transformed.apply(params, state, x)
        # The network output carries a trailing axis of size 1. Without this
        # reshape, (N, 1) predictions broadcast against (N,) labels into an
        # (N, N) matrix and every prediction is scored against every label.
        y_pred = jnp.reshape(y_pred, jnp.shape(y))
        return _binary_cross_entropy_loss(y_pred, y).mean(), model_state

    def step(state, x, y):
        params = state.params
        v, grad = jax.value_and_grad(loss_fn, has_aux=True)(params, state.state, x, y)
        loss_value, model_state = v
        updates, opt_state = opt.update(grad, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        return (
            State(
                opt_state,
                model_state,
                params,
            ),
            loss_value,
        )

    step = jax.jit(step)

    for _ in tqdm(range(n_iters)):
        state, _ = step(state, x, y)

    def classifier(x: jnp.ndarray) -> jnp.ndarray:
        # apply returns (outputs, new_state); only the outputs are of interest.
        return discriminator_transformed.apply(state.params, state.state, x)[0]

    return classifier


def _contourf_face_edges(ax, x, y, z):
    """Draws a 100-level rasterized filled contour with edges coloured like faces."""
    cnt = ax.contourf(x, y, z, levels=100, rasterized=True)
    if hasattr(cnt, "set_edgecolor"):
        # Matplotlib >= 3.8: the ContourSet is itself a Collection and the
        # per-level ``collections`` list is deprecated.
        cnt.set_edgecolor("face")
    else:
        # Older Matplotlib: style each per-level collection individually.
        for c in cnt.collections:
            c.set_edgecolor("face")
    return cnt


def main():
    """Builds the three-panel comparison figure and saves it as ``pull_figure.pdf``.

    Panels, left to right, over a grid of states ``s`` and actions ``a``:
      * the hand-written reward,
      * ``-log(1 - D(s, a))`` for a classifier ``D`` trained to tell expert
        samples from uniformly drawn ones,
      * the density of a Gaussian-process policy fitted to the expert samples
        minus a standard-normal density.
    """
    plt.rc("text", usetex=True)  # needs a working LaTeX installation
    plt.rc("font", family="serif", size=9)

    np.random.seed(1)

    fig, axs = plt.subplots(1, 3, figsize=(6.6, 1.3))

    # define reward

    def function(s):
        return jnp.sin(2.1 * s) + 0.3 * jnp.cos(3.8 * s + 0.1)

    def reward(s, a):
        return -0.1 * np.abs(a - function(s)) - 0.1 * s**2

    n = 150
    s = jnp.linspace(-1.0, 1.0, n)
    a = jnp.linspace(-1.5, 1.5, n)
    S, A = jnp.meshgrid(s, a)
    # Flattened grid coordinates, reused by both learned models below.
    s_flat = np.asarray(S).ravel()
    a_flat = np.asarray(A).ravel()

    R = reward(S, A)
    _contourf_face_edges(axs[0], S, A, R)

    # define samples
    n_exp = 30
    s_exp = np.random.uniform(low=-0.5, high=0.5, size=(n_exp, 1))
    a_exp = function(s_exp)

    n_onl = 30
    s_onl = np.random.uniform(low=-1.0, high=1.0, size=(n_onl, 1))
    a_onl = np.random.uniform(low=-1.5, high=1.5, size=(n_onl, 1))

    # MLP: expert pairs are labelled 1, uniformly sampled pairs 0.
    x_exp = np.concatenate((s_exp, a_exp), axis=1)
    x_onl = np.concatenate((s_onl, a_onl), axis=1)
    X = np.concatenate((x_exp, x_onl), axis=0)
    Y = np.concatenate((np.ones((n_exp,)), np.zeros((n_onl,))), axis=0)

    classifier = train_classifier(X, Y)

    inputs = np.stack((s_flat, a_flat), axis=1)
    outputs = classifier(inputs)
    rewards = -jnp.log(1.0 - outputs)
    R_ = rewards.reshape((n, n))
    _contourf_face_edges(axs[1], S, A, R_)
    axs[1].plot(s_onl, a_onl, "mx", markersize=6)
    axs[1].plot(s_exp, a_exp, "cx", markersize=6)

    # csil + gaussian process policy
    kernel = RBF() + WhiteKernel(noise_level_bounds=(0.1, 100))
    gpr = GaussianProcessRegressor(kernel=kernel, random_state=0).fit(s_exp, a_exp)
    mu, std = gpr.predict(s_flat[:, None], return_std=True)
    # Flatten so a (N, 1) mean cannot broadcast against the (N,) actions.
    mu = np.asarray(mu).ravel()
    std = np.asarray(std).ravel()
    # norm.pdf is already vectorized, so no np.vectorize wrapper is needed.
    # NOTE(review): this is a difference of densities, not of log-densities,
    # despite the "llh" naming - confirm that is intended.
    llhs = norm.pdf(a_flat, loc=mu, scale=std) - norm.pdf(a_flat, loc=0.0, scale=1.0)
    LLH = llhs.reshape((n, n))
    _contourf_face_edges(axs[2], S, A, LLH)
    axs[2].plot(s_exp, a_exp, "cx", markersize=6)

    axs[0].set_title("Reward")
    axs[1].set_title("Classifier")
    axs[2].set_title("Ours")

    axs[0].set_ylabel("$a$")

    for ax in axs:
        ax.set_xlim(-1, 1)
        # Remove ticks first, then their labels.
        ax.set_xticks([])
        ax.set_xticklabels([])
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_xlabel("$s$")

    fig.tight_layout()
    fig.savefig("pull_figure.pdf", bbox_inches="tight")


if __name__ == "__main__":
    # Build and save the figure, then show it in an interactive window.
    main()
    plt.show()
